#!/usr/bin/env python3

import functools
import itertools
import os
import random
from collections import defaultdict, deque
from copy import deepcopy
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
import wandb
from rpi import logger
from rpi.agents.base import Agent
from rpi.agents.mamba import (ActivePolicySelector,ActivePolicySelectorPlus, ActiveStateExplorer,
                                MaxValueFn, MaxValueFnPlus,ValueEnsemble)
from rpi.agents.ppo import update_critic_ensemble
from rpi.agents.mamba import StatePredictorEnsemble
from rpi.state_detection.ensemble_network import NextStatePredActiveStateExplorer,StatePredEnsembleNewStateDetector,EuclainDistanceActiveStateExplorer,WassersteinDistanceActiveStateExplorer
from rpi.state_detection.distance_based_methods import EuclideanDistance, WassersteinDistance
from rpi.agents.ppo import PPOAgent
from rpi.helpers import set_random_seed, to_torch
from rpi.helpers.data import flatten
from rpi.helpers.env import rollout_single_ep, rollout, roll_in_and_out_lops, roll_in_and_out_mamba, rollout_from_state
from rpi.helpers.initializers import ortho_init
from rpi.nn.empirical_normalization import EmpiricalNormalization
from rpi.policies import (GaussianHeadWithStateIndependentCovariance,
                            SoftmaxCategoricalHead)
from rpi.scripts.sweep.default_args import Args
from rpi.value_estimations import (
    _attach_advantage_and_value_target_to_episode,
    _attach_log_prob_to_episodes, _attach_return_and_value_target_to_episode,
    _attach_value_to_episodes, _attach_mean_return_and_value_target_to_episode)
from torch import nn

from rpi.helpers.anonymousslurm import upload_slurm_logs
from rpi.helpers.factory import Factory

from rpi.agents.expert import ExpertAgent
from rpi.agents.mamba import RPIAgent


def _make_env(env_name='DartCartPole-v1', test=False, default_seed=0):
    """Create an environment through ``rpi.helpers.env``.

    The seed is ``default_seed`` for training and ``42 - default_seed`` when
    ``test`` is true.

    * ``dmc*`` names are passed to ``make_env`` with
      ``task_kwargs={'random': seed}``.
    * ``minigrid*`` names have everything up to the last ``':'`` stripped and
      are built with ``make_minigrid_env``.
    * Any other name goes to ``make_env`` unchanged.

    Environments of interest: dmc:Cheetah-run-v1, dmc:Ant-walk-v1,
    DartCartPole-v1, DartDoubleInvertedPendulum-v1 (not tested yet).
    """
    from rpi.helpers import env

    seed = 42 - default_seed if test else default_seed
    make_fn = env.make_env
    kwargs = {}
    if env_name.startswith('dmc'):
        kwargs['task_kwargs'] = {'random': seed}
    elif env_name.startswith('minigrid'):
        # 'minigrid:<id>' -> '<id>'; MiniGrid has its own constructor.
        env_name = env_name.rpartition(':')[2]
        make_fn = env.make_minigrid_env
    return make_fn(env_name, seed=seed, **kwargs)


def remove_cond(key, val):
    """Return True for log entries that should be dropped.

    An entry is dropped when the first '/'-separated segment of its key
    contains 'statepred'. ``val`` is accepted but not inspected.
    """
    return 'statepred' in key.split('/')[0]


def wandb_log(log: dict):
    """Shallow wrapper around ``wandb.log`` that drops entries matched by ``remove_cond``."""
    kept = {}
    for key, val in log.items():
        if remove_cond(key, val):
            continue
        kept[key] = val
    wandb.log(kept)


def load_ref_states(env_spec_id) -> Dict[int, List[np.ndarray]]:
    """Load the saved reference states for ``env_spec_id``.

    Reads ``<Args.saved_states_dir>/<env_spec_id>.pt`` with ``torch.load``.
    That call unpickles the file, so only point it at trusted files.
    """
    return torch.load(Path(Args.saved_states_dir) / f'{env_spec_id}.pt')


def should_flush(episodes, curr_episode, buffer_size):
    """Return True once finished episodes plus the current one hold >= ``buffer_size`` transitions."""
    total = len(curr_episode)
    for finished in episodes:
        total += len(finished)
    return total >= buffer_size


def tolist(tensor: torch.Tensor):
    """Convert a tensor to a Python list; a 0-d tensor becomes a one-element list."""
    return tensor.tolist() if tensor.ndim > 0 else [tensor.item()]


class SwitchingTimeSampler:
    """Sample the time step at which a rollout switches from learner to expert.

    With ``use_moving_ave=True`` the switching time is drawn from a geometric
    distribution, truncated at ``time_limit - 1``. Its mean is a moving average
    of the learner episode lengths fed through ``update_pma``. Otherwise the
    time is drawn uniformly from ``[0, time_limit - 1]``.
    """

    def __init__(self, time_limit, use_moving_ave=True) -> None:
        print('use_moving_ave:', use_moving_ave)
        self.time_limit = time_limit
        self.use_moving_ave = use_moving_ave
        self.pol_mv_avg = PolMvAvg(None)  # moving average of learner episode lengths

    def update_pma(self, val):
        """Feed a mean learner episode length into the moving average."""
        self.pol_mv_avg.update(val)

    def sample(self):
        """Return a switching time in ``[0, time_limit - 1]``.

        The geometric branch starts from 1, so it only returns 0 when
        ``time_limit == 1``.

        Raises:
            ValueError: in moving-average mode when the average episode length
                is below 1 (see ``_geometric_t``).
        """
        if self.use_moving_ave:
            # sample() discards the importance weight.
            t_switch, _ = self._geometric_t(self.pol_mv_avg.val)
        else:
            t_switch = random.randint(0, self.time_limit - 1)

        return t_switch

    def _geometric_t(self, mean, gamma=1.0):
        """Draw ``t ~ Geometric(1 / mean)`` truncated at ``time_limit - 1``.

        Returns:
            ``(t_switch, weight)`` with ``weight = prob / p``. Here ``prob`` is
            the objective's weight for step ``t_switch`` (see
            ``_compute_prob_and_scale``), and ``p`` is the sampling probability
            used for ``t_switch``: the tail mass when truncated, else the pmf.

        Raises:
            ValueError: if ``mean`` is not >= 1 (this includes NaN), e.g. when
                ``update_pma`` was never called. In that case ``1 / mean`` is
                not a valid geometric success probability.

        Copied from microsoft/mamba
        """
        if not mean >= 1:
            raise ValueError(
                f'geometric switching time needs a mean episode length >= 1, got {mean}; '
                'was update_pma() called?')
        prob = 1 / mean
        t_switch = np.random.geometric(prob)  # starts from 1
        if t_switch > self.time_limit - 1:
            t_switch = self.time_limit - 1
            p = (1 - prob) ** t_switch  # tail probability
        else:
            p = (1 - prob) ** (t_switch - 1) * prob
        step_prob, _ = self._compute_prob_and_scale(t_switch, gamma)

        return t_switch, step_prob / p

    def _compute_prob_and_scale(self, t, gamma):
        r"""Treat the weighting in a problem as probability. Compute the
        probability for a time step and the sum of the weights.

        For the objective below,
            \sum_{t=0}^{T-1} \gamma^t c_t
        where T is finite and \gamma in [0,1], or T is infinite and gamma<1.
        It computes
            scale = \sum_{t=0}^{T-1} \gamma^t
            prob = \gamma^t / scale

        Only finite ``time_limit`` is supported; infinite horizons raise
        NotImplementedError.

        Copied from microsoft/mamba
        """
        assert t <= self.time_limit - 1
        if self.time_limit < float("Inf"):
            p0 = gamma ** np.arange(self.time_limit)
            sump0 = np.sum(p0)
            prob = p0[t] / sump0
        else:
            raise NotImplementedError('I have no idea what to do')
        return prob, sump0


class PolMvAvg:
    r"""An estimator based on polynomially weighted moving average.

    The estimate after N calls is computed as
        val = \sum_{n=1}^N n^power x_n / nor_N
    where nor_N is equal to \sum_{n=1}^N n^power, and power is a parameter.
    With the default ``power=0`` and unit weights this is a plain running mean.

    Copied from microsoft/mamba
    """

    def __init__(self, val, power=0, weight=0.0):
        # Running weighted sums of values and of weights. An initial `val`
        # counts with prior weight `weight` (0 by default, i.e. no prior).
        self._val = val * weight if val is not None else 0.0
        self._nor = weight
        self.power = power
        self._itr = 1  # 1-based call counter n used in n**power

    def update(self, val, weight=1.0):
        """Fold in one observation ``val`` with weight ``weight * n**power``."""
        self._val = self.mvavg(self._val, val * weight, self.power)
        self._nor = self.mvavg(self._nor, weight, self.power)
        self._itr += 1

    def mvavg(self, old, new, power):
        """Return ``old + new * n**power`` where n is the current call index."""
        return old + new * self._itr ** power

    @property
    def val(self):
        """Current weighted average; the denominator is floored at 1e-8 to avoid division by zero."""
        return self._val / np.maximum(1e-8, self._nor)


def get_ref_stateacts(env, max_episode_len: int) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Collect reference (state, action) pairs by rolling out a uniformly random policy.

    For each minimum length in (3, 5, 10, 12, 14), episodes are re-rolled
    until one has at least that many transitions. The state and action of
    that episode's final transition are then recorded.

    NOTE(review): the stored pair is the episode's *last* transition, not the
    one at index ``min_len - 1``. Confirm this is intended. The retry loop
    also never terminates if the env cannot reach ``min_len`` steps.
    """
    def act_randomly(obs):
        return env.action_space.sample()

    pairs = []
    for min_len in (3, 5, 10, 12, 14):
        while True:
            episode = rollout_single_ep(env, act_randomly, max_episode_len)
            if len(episode) >= min_len:
                break
        last = episode[-1]
        pairs.append((last['state'], last['action']))

    return pairs

# NOTE: collects `num_rollouts` trajectories for each expert.
def collect_expert_episodes(env, experts: List[Agent], max_episode_len: int, num_rollouts: int,
                            tgtval_method: str, learner_gamma: float, learner_lmd: float):
    """Roll out every expert ``num_rollouts`` times and group the episodes by expert.

    Rollouts are interleaved: experts 0..K-1 each run once, then the cycle
    repeats. Each expert acts with ``mode=Args.deterministic_experts``.

    No values or value targets are attached here. ``tgtval_method``,
    ``learner_gamma`` and ``learner_lmd`` are accepted for call-site
    compatibility but are unused.

    Returns:
        A list with one list of episodes per expert, in ``experts`` order.
    """
    # Build each expert's policy once instead of once per rollout.
    expert_policies = [functools.partial(expert.act, mode=Args.deterministic_experts) for expert in experts]
    expert2episodes = [[] for _ in experts]
    for _ in range(num_rollouts):
        for expert_idx, policy in enumerate(expert_policies):
            expert2episodes[expert_idx].append(rollout_single_ep(env, policy, max_episode_len))

    return expert2episodes


def _recompute_value_targets(episodes, tgtval_method: str, gamma: float, lmd: float, use_learner: bool):
    """Attach fresh value targets, in place, to every episode according to ``tgtval_method``.

    Raises:
        ValueError: if ``tgtval_method`` is not one of 'monte-carlo',
            'monte-carlo-bootstrap', 'gae' or 'avg-monte-carlo'.
    """
    if tgtval_method == 'monte-carlo':
        for episode in episodes:
            _attach_return_and_value_target_to_episode(episode, gamma)
    elif tgtval_method == 'monte-carlo-bootstrap':
        for episode in episodes:
            _attach_return_and_value_target_to_episode(episode, gamma, bootstrap=True)
    elif tgtval_method == 'gae':
        # NOTE(review): the keyword is spelled `leaner`, exactly as before. Confirm
        # it matches the signature of _attach_advantage_and_value_target_to_episode.
        extra_kwargs = {'leaner': True} if use_learner else {}
        for episode in episodes:
            _attach_advantage_and_value_target_to_episode(episode, gamma, lmd, algorithm=Args.algorithm, **extra_kwargs)
    elif tgtval_method == 'avg-monte-carlo':
        for episode in episodes:
            _attach_mean_return_and_value_target_to_episode(episode)
    else:
        raise ValueError(f'Unknown tgtval_method: {tgtval_method!r}')


def update_expert_vfn(experts, expert_rollouts, num_epochs, batch_size: int, num_val_iterations: int,
                      tgtval_method: str, learner_gamma: float, learner_lmd: float, learner:Agent=None) -> List[List[dict]]:
    """Fit each expert's value-function ensemble on that expert's own rollouts.

    Experts with fewer than two transitions are skipped. For every other
    expert the loop below runs ``num_val_iterations`` times:

    1. attach current value predictions to the transitions;
    2. recompute value targets with ``tgtval_method``;
    3. run ``update_critic_ensemble`` for
       ``max(1, num_epochs // num_val_iterations)`` epochs.

    Args:
        experts: agents whose ``vfn`` / ``obs_normalizer`` are used and trained.
        expert_rollouts: ``expert_rollouts[k]`` is the list of episodes collected by expert k.
        num_epochs: total critic epochs, spread over the value iterations.
        batch_size: minibatch size for ``update_critic_ensemble``.
        num_val_iterations: number of attach-values / recompute-targets / fit rounds.
        tgtval_method: 'monte-carlo', 'monte-carlo-bootstrap', 'gae' or 'avg-monte-carlo'.
        learner_gamma: discount used for the value targets.
        learner_lmd: GAE lambda used for the 'gae' targets.
        learner: if given, ``learner.vfn`` (with ``normalize_input=True``) is also
            passed to ``_attach_value_to_episodes``, and GAE targets use the learner flag.

    Returns:
        One list of log dicts per expert. A list is empty if no training ran
        for that expert.

    Raises:
        ValueError: if ``tgtval_method`` is unknown (only raised once an expert is trained).
    """
    logs_per_expert = [[] for _ in experts]

    if num_epochs == 0 or num_val_iterations == 0:
        return logs_per_expert

    epochs_per_iteration = max(1, num_epochs // num_val_iterations)
    attach_kwargs = {}
    if learner is not None:
        attach_kwargs['learner_vfn'] = functools.partial(learner.vfn, normalize_input=True)

    for expert_idx, expert in enumerate(experts):
        episodes = expert_rollouts[expert_idx]
        expert_k_transitions = flatten(episodes)

        if len(expert_k_transitions) < 2:
            print(f'roll out by the expert {expert_idx} was too short!')
            print('len(expert_k_transitions)', len(expert_k_transitions))
            continue

        itr = 0
        for _ in range(num_val_iterations):
            _attach_value_to_episodes(expert.vfn, expert_k_transitions, obs_normalizer=expert.obs_normalizer, **attach_kwargs)
            _recompute_value_targets(episodes, tgtval_method, learner_gamma, learner_lmd, use_learner=learner is not None)

            # NOTE: the number of epochs per iteration may change the behavior quite a lot
            # (original note: "100 for CartPole, DIP").
            _, loss_critic_history = update_critic_ensemble(expert, expert_k_transitions, num_epochs=epochs_per_iteration, batch_size=batch_size, std_from_means=True)
            for loss in loss_critic_history:
                logs_per_expert[expert_idx].append({
                    f'loss-critic-{expert_idx}': loss,
                    f'num-transitions-{expert_idx}': len(expert_k_transitions),
                    'vi-step': itr,
                })
                itr += 1
    return logs_per_expert


def evaluate_expert_on_learner_distr(env, learner, experts, max_episode_len: int, num_states: int = 100):
    """Probe how each expert values, and performs from, states visited by the learner.

    Procedure as implemented:

    1. Draw ``num_states`` switching times uniformly from ``[1, max_episode_len - 1]``.
       For each, roll in the learner with that horizon, saving the simulator
       state, and keep the final observation and sim state. There is no check
       that the roll-in actually reached the switching time.
    2. Evaluate every expert's value ensemble on all kept observations and
       record the mean and std over the stacked ensemble predictions.
    3. From each saved sim state, roll out every expert for up to
       ``max_episode_len`` steps and record the sum of rewards.

    For the first ``num_videos`` states, rollouts are requested with
    ``save_video=True``. When transitions carry a 'frame', the video is logged
    to wandb.

    Returns:
        A dict of wandb-ready metrics keyed by expert name.
    """
    switching_times = [random.randint(1, max_episode_len - 1) for _ in range(num_states)]
    observations, sim_states = [], []
    num_videos = 5
    for i, switching_time in enumerate(switching_times):
        ep = rollout_single_ep(env, learner.act, max_episode_len=switching_time, save_sim_state=True, save_video=(i < num_videos))
        observations.append(ep[-1]['state'])
        sim_states.append(ep[-1]['sim_state'])

        if 'frame' in ep[0]:
            # Save video on wandb
            print('Saving a video to wandb...')
            frames = np.asarray([tr['frame'] for tr in ep], dtype=np.uint8)
            wandb.log({f'probe/video-state-{i:02d}': wandb.Video(frames.transpose(0, 3, 1, 2), fps=20, format='mp4')})

    observations = to_torch(observations)

    # Evaluate expert value functions on all collected observations.
    with torch.no_grad():
        expert2values = defaultdict(list)
        expert2stds = defaultdict(list)
        for expert in experts:
            distrs = expert.vfn.forward_all(observations, normalize_input=True)
            values = torch.stack([distr.mean for distr in distrs], dim=0)
            expert2values[expert.name].append(values.mean().item())
            expert2stds[expert.name].append(values.std().item())

    # Evaluate states by rolling out each expert from the saved sim state.
    with torch.no_grad():
        expert2retvalues = defaultdict(list)
        for i, (sim_state, obs) in enumerate(zip(sim_states, observations)):
            print(f'rolling out expert {i} / {len(sim_states)}')
            for expert in experts:
                # NOTE: This runs a new episode with horizon max_episode_len from sim_state
                ep = rollout_from_state(env, expert.act, sim_state, obs, max_episode_len, save_video=(i < num_videos))
                retval = sum([tr['reward'] for tr in ep])
                expert2retvalues[expert.name].append(retval)

                if 'frame' in ep[0]:
                    # Save video on wandb
                    print('saving a video to wandb...')
                    frames = np.asarray([tr['frame'] for tr in ep], dtype=np.uint8)
                    wandb.log({f'probe/video-state-{i:02d}-exp{expert.name}': wandb.Video(frames.transpose(0, 3, 1, 2), fps=20, format='mp4')})

    # Named `logs` (not `wandb_log`) so the module-level wandb_log() is not shadowed.
    logs = {
        **{f'probe/learner-states-avgval-expert-{name}': np.mean(meanval) for name, meanval in expert2values.items()},
        **{f'probe/learner-states-avgstd-expert-{name}': np.mean(meanstd) for name, meanstd in expert2stds.items()},
        **{f'probe/learner-states-retval-expert-{name}': np.mean(retvals) for name, retvals in expert2retvalues.items()},
        **{f'probe/learner-states-retval-expert-hist-{name}': wandb.Histogram(retvals) for name, retvals in expert2retvalues.items()},
    }
    return logs


def _make_active_state_explorer(experts: Sequence[Agent], expert2episodes, ase_sigma: Optional[float]):
    """Build the active-state explorer selected by ``Args.ase_uncertainty`` ('lops-aps-ase' only).

    A fresh explorer is built for every rollout because the caller may update
    ``ase_sigma`` between rollouts.

    Raises:
        ValueError: if ``Args.ase_uncertainty`` is not a supported option.
    """
    value_fns = [expert.vfn for expert in experts]
    uncertainty = Args.ase_uncertainty
    if uncertainty == 'value_std':
        return ActiveStateExplorer(value_fns=value_fns, sigma=ase_sigma, uncertainty="std")
    if uncertainty == 'value_max_gap':
        return ActiveStateExplorer(value_fns=value_fns, sigma=ase_sigma, uncertainty="max_gap")
    if uncertainty == 'next_state_mean':
        return NextStatePredActiveStateExplorer(value_fns=value_fns, state_pred_ensemble=StatePredEnsembleNewStateDetector, all_episodes=expert2episodes, sigma=ase_sigma)
    if uncertainty == 'euclidean_mean':
        return EuclainDistanceActiveStateExplorer(value_fns=value_fns, euclain_distance=EuclideanDistance, all_episodes=expert2episodes, sigma=ase_sigma)
    if uncertainty == 'wasserstein_mean':
        return WassersteinDistanceActiveStateExplorer(value_fns=value_fns, ws_distance=WassersteinDistance, all_episodes=expert2episodes, sigma=ase_sigma)
    raise ValueError(f'Unknown Args.ase_uncertainty: {uncertainty!r}')


def _summarize_riro_logs(num_experts, log_switch_time, log_expert_traj_len, log_expidx,
                         log_switch_valmeans, log_switch_valstds, log_ase_sigma,
                         log_uncertainty_max, log_uncertainty_min) -> dict:
    """Turn the per-rollout statistics collected by ``roll_in_and_out`` into a wandb log dict."""
    print('switch times', log_switch_time)
    print('selected_expert', log_expidx)
    base_logs = {
        'riro/learner-swtchtime-hist': wandb.Histogram(log_switch_time),
        'riro/learner-swtchtime-mean': np.mean(log_switch_time),
        'riro/learner-swtchtime-min': min(log_switch_time),
        'riro/learner-swtchtime-max': max(log_switch_time),
        'riro/expert_traj_len-hist': wandb.Histogram(log_expert_traj_len),
        'riro/expert_traj_len-min': min(log_expert_traj_len),
        'riro/expert_traj_len-mean': np.mean(log_expert_traj_len),
        'riro/expert_traj_len-max': max(log_expert_traj_len),
        'riro/selected_expert': wandb.Histogram(log_expidx),
        **{f'riro/selected_expert-{expidx}': (np.asarray(log_expidx) == expidx).mean() for expidx in range(num_experts)},
    }

    # Optional entries are placed before the base entries (same key order as before).
    extra_logs = {}
    if len(log_uncertainty_max) > 0 and len(log_uncertainty_min) > 0:
        # Key spelling 'uncertianty' kept as-is: existing dashboards may depend on it.
        extra_logs['riro/uncertianty_max'] = np.max(log_uncertainty_max)
        extra_logs['riro/uncertianty_min'] = np.min(log_uncertainty_min)
    if len(log_ase_sigma) > 0:
        extra_logs['riro/ase_sigma'] = np.mean(log_ase_sigma)
    if len(log_switch_valmeans) > 0:
        extra_logs['riro/selected_expert_val_mean'] = np.mean(log_switch_valmeans)
        extra_logs['riro/selected_expert_val_std'] = np.mean(log_switch_valstds)
    return {**extra_logs, **base_logs}


def roll_in_and_out(env, learner: Agent, experts: Sequence[Agent], swtime_sampler: SwitchingTimeSampler, num_rollouts: int, gamma: float, lmd: float, max_episode_len: int, return_wandblogs: bool = True, ase_sigma: Optional[float] = None, expert2episodes: Optional[Callable] = None , switching_state_callback: Optional[Callable] = None,itr=None,num_train_steps=None):
    """Collect ``num_rollouts`` episodes that roll in with the learner and roll out with an expert.

    The strategy depends on ``Args.algorithm``:

    * 'lops-aps-ase': ``roll_in_and_out_lops`` with an active-state explorer
      chosen by ``Args.ase_uncertainty`` and ``switch_time=None``.
    * 'rpi': ``roll_in_and_out_lops`` with ``ActivePolicySelectorPlus``, which
      also receives ``learner.vfn``, and a sampled switching time.
    * 'lops-aps' / 'lops-il' / 'lops-lambda': ``roll_in_and_out_lops`` with
      ``ActivePolicySelector`` and a sampled switching time.
    * 'mamba': ``roll_in_and_out_mamba`` with a uniformly chosen expert.

    Sometimes ``ase_sigma`` is re-estimated after every rollout. This happens
    when ``Args.use_ase_sigma_ratio`` is set and the rollout reported
    uncertainties (only 'lops-aps-ase' can). The new value is
    ``min(uncertainty_min) + Args.ase_sigma_ratio * (max(uncertainty_max) - min(uncertainty_min))``.

    ``gamma``, ``lmd`` and ``switching_state_callback`` are not used here.
    ``expert2episodes`` is annotated ``Optional[Callable]``, but it is
    forwarded as ``all_episodes`` to the distance- and state-prediction-based
    explorers.

    Returns:
        ``(all_episodes, switch_times, ro_expert_inds, wandb_logs)`` if
        ``return_wandblogs`` is true, otherwise the first three.

    Raises:
        ValueError: for an unknown ``Args.algorithm`` or ``Args.ase_uncertainty``,
            or when 'mamba' is run without experts. These previously called
            ``exit()``, which ended the process with status 0.
    """
    assert Args.algorithm != 'pg-gae'

    lops_algorithms = ('lops-aps-ase', 'rpi', 'lops-aps', 'lops-il', 'lops-lambda')

    all_episodes = []
    switch_times = []
    ro_expert_inds = []

    log_expidx = []
    log_switch_time = []
    log_expert_traj_len = []
    log_switch_valmeans = []
    log_switch_valstds = []
    log_uncertainty_max = []
    log_uncertainty_min = []
    log_ase_sigma = []

    for _ in range(num_rollouts):
        algorithm = Args.algorithm
        if algorithm in lops_algorithms:
            # Always sample first, so RNG consumption matches across variants.
            switching_time = swtime_sampler.sample()
            expert_acts = [functools.partial(expert.act, mode=Args.deterministic_experts) for expert in experts]
            if algorithm == 'lops-aps-ase':
                # switch_time=None: roll_in_and_out_lops decides the switching time itself;
                # the sampled value above is overwritten by its return value.
                explorer = _make_active_state_explorer(experts, expert2episodes, ase_sigma)
                selector = ActivePolicySelector(value_fns=[expert.vfn for expert in experts])
                requested_switch_time = None
            elif algorithm == 'rpi':
                explorer = None
                selector = ActivePolicySelectorPlus(value_fns=[expert.vfn for expert in experts], value_learner_fn=learner.vfn, itr=itr, num_train_steps=num_train_steps, expert=Args.aps_expert, learner=Args.aps_learner)
                requested_switch_time = switching_time
            else:
                explorer = None
                selector = ActivePolicySelector(value_fns=[expert.vfn for expert in experts])
                requested_switch_time = switching_time

            (episode, switching_time, expert_idx, values_at_switching_state,
             uncertainty_max, uncertainty_min) = roll_in_and_out_lops(
                env,
                learner.act,
                expert_acts,
                active_state_explorer=explorer,
                active_policy_selector=selector,
                max_episode_len=max_episode_len,
                switch_time=requested_switch_time,
            )
            if algorithm != 'lops-aps-ase':
                # Uncertainties are only tracked for the active-state-exploration variant.
                uncertainty_min = None
                uncertainty_max = None
        elif algorithm == 'mamba':
            if len(experts) == 0:
                raise ValueError('mamba requires at least one expert, but none were given')

            switching_time = swtime_sampler.sample()
            expert_idx = random.randint(0, len(experts) - 1)
            episode, values_at_switching_state = roll_in_and_out_mamba(env, learner.act, experts[expert_idx].act, experts[expert_idx].vfn, switching_time, max_episode_len)
            uncertainty_min = None
            uncertainty_max = None
        else:
            raise ValueError(f'Unknown algorithm: {Args.algorithm}')

        all_episodes.append(episode)
        print('riro ep len', len(episode))
        switch_times.append(switching_time)
        ro_expert_inds.append(expert_idx)

        log_switch_time.append(switching_time)
        log_expert_traj_len.append(len(episode) - switching_time)

        if values_at_switching_state is not None:
            log_switch_valmeans.append(values_at_switching_state.mean.item())
            log_switch_valstds.append(values_at_switching_state.std.item())

        if uncertainty_min is not None:
            if isinstance(uncertainty_min, torch.Tensor):
                log_uncertainty_min.append(uncertainty_min.cpu().detach().numpy())
                log_uncertainty_max.append(uncertainty_max.cpu().detach().numpy())
            else:
                log_uncertainty_min.append(uncertainty_min)
                log_uncertainty_max.append(uncertainty_max)

            if Args.use_ase_sigma_ratio:
                # Re-estimate the explorer threshold from the uncertainty range seen so far.
                ase_sigma = np.min(log_uncertainty_min) + Args.ase_sigma_ratio * (np.max(log_uncertainty_max) - np.min(log_uncertainty_min))
                log_ase_sigma.append(ase_sigma)

        # Just for logging: the expert policy was rolled out at least one step.
        if switching_time < len(episode):  # NOTE: 0 <= switching_time <= len(episode)
            log_expidx.append(expert_idx)

    if return_wandblogs:
        wandb_logs = _summarize_riro_logs(
            len(experts), log_switch_time, log_expert_traj_len, log_expidx,
            log_switch_valmeans, log_switch_valstds, log_ase_sigma,
            log_uncertainty_max, log_uncertainty_min,
        )
        return all_episodes, switch_times, ro_expert_inds, wandb_logs

    return all_episodes, switch_times, ro_expert_inds



def limit_num_transitions(episodes: List[List[dict]], max_transitions: int):
    """Keep only the most recent ``max_transitions`` transitions.

    Episodes are ordered oldest to newest. Whole episodes are kept from the end
    of the list, and the oldest kept episode is trimmed from its front so the
    total is exactly ``max_transitions``.

    Args:
        episodes: list of episodes, each a list of transition dicts.
        max_transitions: maximum number of transitions to keep.

    Returns:
        ``episodes`` itself (same object) when it holds fewer than
        ``max_transitions`` transitions or is empty. Otherwise, a deep copy of
        the kept (and trimmed) episodes; no transition aliases the input.
    """
    assert isinstance(episodes, list), type(episodes)
    if not episodes:
        return episodes
    assert isinstance(episodes[0], list), type(episodes[0])
    # Check the first transition only, if the first episode has one.
    assert all(isinstance(tr, dict) for tr in episodes[0][:1]), 'transitions must be dicts'

    num_trans = 0
    cutoff_ep_idx, num_overrun = None, 0
    # Walk from the newest episode back until the budget is reached.
    for idx, episode in enumerate(reversed(episodes)):
        num_trans += len(episode)
        if num_trans >= max_transitions:
            num_overrun = num_trans - max_transitions
            cutoff_ep_idx = idx
            break

    if cutoff_ep_idx is None:
        return episodes

    _num_transitions = sum(len(episode) for episode in episodes)
    print('debug: original transitions:', _num_transitions)
    print('debug: original len(episodes):', len(episodes))
    print('debug: cutoff_ep_idx:', cutoff_ep_idx)
    print('debug: num_overrun:', num_overrun)

    new_episodes = deepcopy(episodes[-cutoff_ep_idx - 1:])  # Remove remaining old episodes
    # Trim the copy (not the input), so the result never shares transitions with `episodes`.
    new_episodes[0] = new_episodes[0][num_overrun:]  # Remove old transitions

    _num_transitions = sum(len(episode) for episode in new_episodes)
    print(f'debug: after clipping: {_num_transitions} where max is {max_transitions}')
    print('debug: after clipping len(new_episodes):', len(new_episodes))
    assert _num_transitions <= max_transitions
    return new_episodes

# Evaluation helper: runs eval_fn on agents and probes expert value networks on reference states.
class Evaluator:
    """Run policy evaluations and probe expert value networks.

    ``best_so_far`` tracks the highest mean evaluation return seen by
    ``evaluate`` calls made with ``update_best=True``.
    """

    def __init__(self, make_env, max_episode_len) -> None:
        self.make_env = make_env  # environment factory forwarded to eval_fn
        self.max_episode_len = max_episode_len
        self.best_so_far = -np.inf

    def evaluate(self, agent, num_eval_episodes, update_best=False, save_video_num_ep=0):
        """Evaluate ``agent`` with ``rpi.evaluation.eval_fn`` and return wandb-style logs.

        Stats whose key starts with '_' are excluded from the returned logs.
        ``'_returns'`` is still read here to update ``best_so_far`` when
        ``update_best`` is true.

        Returns:
            dict mapping ``'eval/<stat>'`` to values, plus ``'eval/best-so-far'``
            when ``update_best`` is true.
        """
        from rpi.evaluation import eval_fn
        stats = eval_fn(self.make_env, agent, self.max_episode_len, num_episodes=num_eval_episodes, save_video_num_ep=save_video_num_ep, verbose=True)
        logs = {f'eval/{key}': val for key, val in stats.items() if not key.startswith('_')}

        if update_best:
            self.best_so_far = max(self.best_so_far, np.mean(stats['_returns']))
            logs = {'eval/best-so-far': self.best_so_far, **logs}

        return logs

    def inspect_value_nn(self, experts: List[Agent], ref_states: List[np.ndarray]) -> dict:
        """Log each expert's value-ensemble statistics on every reference state.

        Returns:
            A dict with keys ``'probe-expert/<expert_idx>-ref_<step:04d>_{std,mean,means,std_from_means}'``.
        """
        logs = {}
        for idx, expert in enumerate(experts):
            for ref_step, ref_state in enumerate(ref_states):
                with torch.no_grad():
                    vfn_stats = expert.vfn.forward_stats(to_torch(ref_state).unsqueeze(0), normalize_input=True)

                # Per-member means of the ensemble, computed once and reused below.
                all_means = tolist(vfn_stats.all_means.squeeze())
                prefix = f'probe-expert/{idx}-ref_{ref_step:04d}'
                # update() in place instead of rebuilding the dict every iteration (was quadratic).
                logs.update({
                    f'{prefix}_std': vfn_stats.std,
                    f'{prefix}_mean': vfn_stats.mean,
                    f'{prefix}_means': wandb.Histogram(all_means),
                    f'{prefix}_std_from_means': np.std(all_means),
                })
        return logs

# NOTE: main training loop.
def _resolve_ase_sigma(stddev_baseline):
    """Return the active-state-explorer sigma configured by ``Args``.

    With ``Args.use_ase_sigma_coef`` the sigma is ``Args.ase_sigma_coef`` times
    ``stddev_baseline`` (return stddev of the best expert measured during preparation);
    otherwise it is the fixed ``Args.ase_sigma``.

    Raises:
        ValueError: if the coefficient mode is requested but no baseline was measured
            (e.g. the expert list is empty).
    """
    if not Args.use_ase_sigma_coef:
        return Args.ase_sigma
    if stddev_baseline is None:
        raise ValueError('Args.use_ase_sigma_coef requires at least one expert to measure stddev_baseline')
    return Args.ase_sigma_coef * stddev_baseline


def _update_expert_vfns(experts, expert2episodes, learner, *, pretrain):
    """Fit the experts' value functions on their episode buffers via ``update_expert_vfn``.

    Pretraining uses ``Args.pret_num_epochs`` / ``Args.pret_num_val_iterations``; online
    updates use ``Args.num_epochs`` and a single value iteration. ``learner_lmd`` depends on
    ``Args.algorithm``:
      * ``rpi``: ``learner.lambd`` (and ``learner`` is passed through);
      * ``lops-il``: 0;
      * ``lops-lambda``: ``Args.lmd`` when pretraining, ``learner.lambd`` online;
      * otherwise: ``learner.lambd``.

    Returns:
        Whatever ``update_expert_vfn`` returns (iterated by callers as one list of log dicts per expert).
    """
    if pretrain:
        num_epochs, num_val_iterations = Args.pret_num_epochs, Args.pret_num_val_iterations
    else:
        num_epochs, num_val_iterations = Args.num_epochs, 1
    common = dict(num_epochs=num_epochs, batch_size=Args.batch_size, num_val_iterations=num_val_iterations,
                  tgtval_method=Args.expert_tgtval, learner_gamma=learner.gamma)

    if Args.algorithm == 'rpi':
        return update_expert_vfn(experts, expert2episodes, learner_lmd=learner.lambd, learner=learner, **common)

    if Args.algorithm == 'lops-il':
        learner_lmd = 0
    elif Args.algorithm == 'lops-lambda' and pretrain:
        learner_lmd = Args.lmd
    else:
        learner_lmd = learner.lambd
    return update_expert_vfn(experts, expert2episodes, learner_lmd=learner_lmd, **common)


def _log_expert_vfn_updates(log_list, itr):
    """Log, per expert, the mean of every online value-update metric except ``vi-step``."""
    for logs in log_list:
        if not logs:  # nothing recorded for this expert; indexing logs[0] would fail
            continue
        # NOTE(review): keys carry no expert index, so all experts log under the same metric names.
        _logs = {f'train-expert/{key}': np.mean([log[key] for log in logs]) for key in logs[0].keys() if key != 'vi-step'}
        wandb_log({**_logs, 'step': itr})


def _merge_riro_episodes(all_episodes, switching_times, ro_expert_inds, experts, expert2episodes,
                         learner_episodes, *, learner_in_pool):
    """Distribute roll-in/roll-out episodes between the expert buffers and the learner buffer.

    For each episode, the part from the switching time on (roll-out) is appended to the
    buffer of the expert that produced it, and the part before it (roll-in by the learner)
    goes to ``learner_episodes`` unless ``Args.use_riro_for_learner_pi == 'none'``.

    When ``learner_in_pool`` is True (RPI), the learner itself can be the roll-out policy.
    It is identified by index ``len(expert2episodes)`` (one past the last expert), or by any
    index when there are no experts. Its non-empty roll-outs also go to ``learner_episodes``.

    Mutates ``expert2episodes``, ``learner_episodes`` and the experts' obs normalizers in place.
    """
    use_riro_for_learner = Args.use_riro_for_learner_pi != 'none'
    num_experts = len(expert2episodes)
    for episode, sw_time, ro_idx in zip(all_episodes, switching_times, ro_expert_inds):
        roll_in = episode[:sw_time]  # NOTE: sw_time may be bigger than len(episode) itself!
        roll_out = episode[sw_time:]
        # Consistent with the expert check below: the learner sits right after the last expert.
        is_learner_rollout = learner_in_pool and (num_experts == 0 or ro_idx == num_experts)

        if len(roll_out) > 0 and not is_learner_rollout:
            if learner_in_pool:
                print("length(expert2episodes):", num_experts, ", _ro_exp_idx: ", ro_idx)
            expert2episodes[ro_idx].append(roll_out)
            # Expose new transitions to the expert's obs_normalizer
            if experts[ro_idx].obs_normalizer is not None:
                experts[ro_idx].obs_normalizer.experience(to_torch([tr['state'] for tr in roll_out]))

        # Roll-outs by the learner itself; empty segments are skipped (they carry no transitions).
        if is_learner_rollout and len(roll_out) > 0 and use_riro_for_learner:
            learner_episodes.append(roll_out)
            print("learner_episodes1:: ", len(learner_episodes))

        # Roll-in segments; learner's obs_normalizer is updated right before its policy update
        if len(roll_in) > 0 and use_riro_for_learner:
            learner_episodes.append(roll_in)
            if learner_in_pool:
                print("learner_episodes2:: ", len(learner_episodes))


def _extend_with_expert_episodes(experts, expert2episodes, learner_episodes):
    """Append copies of every expert buffer to ``learner_episodes`` with expert log-probs attached.

    Deep copies keep the ``log_prob`` key off the transitions stored in ``expert2episodes``.
    """
    assert len(experts) == len(expert2episodes)
    for expert, expert_eps in zip(experts, expert2episodes):
        expert_eps = deepcopy(expert_eps)
        expert_transitions = flatten(expert_eps)
        # NOTE BUG: When Args.deterministic_experts is False (i.e., expert transitions are collected in a stochastic manner),
        # somehow the log probability becomes excessively small (~ -1000) and this causes probability ratio in PPO loss to blow up,
        # making the loss infinity immediately.
        _attach_log_prob_to_episodes(expert.pi, expert_transitions, obs_normalizer=expert.obs_normalizer)
        print(f'exp: {expert} transitions', len(expert_transitions))
        log_probs = np.asarray([tr['log_prob'] for tr in expert_transitions])
        print(f'exp transitions log_prob max {log_probs.max()}\tmin {log_probs.min()}\tisnan any {np.isnan(log_probs).any()}')
        learner_episodes.extend(expert_eps)


def _attach_learner_targets(learner_episodes, learner, itr):
    """Attach advantages / value targets to each learner episode; return the episode lengths."""
    ep_lens = []
    for episode in learner_episodes:
        if Args.algorithm == 'rpi':
            _attach_advantage_and_value_target_to_episode(episode, learner.gamma, learner.lambd, algorithm=Args.algorithm, leaner=True)
        elif Args.algorithm == 'lops-il':
            print("lops-il:lambda:", 0)
            _attach_advantage_and_value_target_to_episode(episode, learner.gamma, lambd=0, algorithm=Args.algorithm)
        elif Args.algorithm == 'lops-lambda':
            # Imitation targets (lambda=0) until the switch round, then RL targets (lambda=1).
            if itr < Args.maxplus_switch_rl_round:
                print("lops-lambda:IL,", "itr:", itr, ",switch_rl_round:", Args.maxplus_switch_rl_round, ",lambda:", Args.lmd)
                _attach_advantage_and_value_target_to_episode(episode, learner.gamma, lambd=0, algorithm=Args.algorithm)
            else:
                print("lops-lambda:RL,", "itr:", itr, ",switch_rl_round:", Args.maxplus_switch_rl_round, ",lambda:", 1)
                _attach_advantage_and_value_target_to_episode(episode, learner.gamma, lambd=1, algorithm=Args.algorithm)
        else:
            _attach_advantage_and_value_target_to_episode(episode, learner.gamma, learner.lambd, algorithm=Args.algorithm)
        ep_lens.append(len(episode))
    return ep_lens


def train_lops(make_env: Callable, evaluator: Evaluator, learner: Agent, experts: List[Agent], swtime_sampler: SwitchingTimeSampler, num_train_steps: int, max_episode_len = 1000, eval_every: int = 1):
    """Train ``learner`` for ``num_train_steps`` iterations using the variant in ``Args.algorithm``.

    Preparation (all algorithms except ``pg-gae``): measure the learner's initial episode
    length for the switching-time sampler, evaluate each expert, and record the return
    stddev of the best expert as ``stddev_baseline``. Unless the algorithm is ``pg-gae``
    (or ``rpi`` without experts), expert episodes are collected and the experts' value
    functions are pretrained on them.

    Each iteration: periodically evaluate the learner; for non-``pg-gae`` algorithms, roll in
    with the learner / roll out with a selected policy and refit the expert value functions;
    roll out the learner; attach values, log-probs and advantages; update the learner.

    Args:
        make_env: environment factory; called once here with no arguments.
        evaluator: used for expert evaluation and periodic learner evaluation.
        learner: agent being trained.
        experts: oracle agents (may be empty for ``rpi``).
        swtime_sampler: switching-time sampler whose moving average tracks episode lengths.
        num_train_steps: number of outer iterations.
        max_episode_len: per-episode step cap for rollouts.
        eval_every: evaluate the learner every ``eval_every`` iterations.
    """
    # .get: avoid a KeyError when called outside the __main__ entry point that sets this variable.
    logger.info('cvd', os.environ.get('CUDA_VISIBLE_DEVICES'))

    env = make_env()

    ### Preparation ###
    stddev_baseline = None
    if Args.algorithm != 'pg-gae':
        # Rollout 10 episodes just to get the initial average episode lengths
        ep_lengths = []
        for _ in range(10):
            episode = rollout_single_ep(env, functools.partial(learner.act, mode=Args.deterministic_experts), max_episode_len)
            ep_lengths.append(len(episode))
        swtime_sampler.update_pma(np.mean(ep_lengths))

        # Evaluate the experts, keeping (mean, std) of their returns
        mean2stddev = []
        for expert in experts:
            logs = evaluator.evaluate(expert, num_eval_episodes=32)
            mean2stddev.append((logs['eval/returns_mean'], logs['eval/returns_std']))
            logs = {f'prep/expert-{expert.name}-{key}': val for key, val in logs.items() if not key.startswith('_')}
            wandb_log({'step': 0, **logs})

        if experts:
            # Return stddev of the expert with the highest mean return
            stddev_baseline = max(mean2stddev)[1]
        print('mean2stddev', mean2stddev if experts else None)
        print('stddev baseline', stddev_baseline)
        wandb_log({'step': 0, 'prep/stddev_baseline': stddev_baseline})

    ### Rollout experts and pretrain value functions ###
    if Args.algorithm == 'pg-gae' or (Args.algorithm == 'rpi' and len(experts) == 0):
        expert2episodes = []
    else:
        ## For each expert k, collect data D^k by rolling out pi^k
        logger.info('Collecting expert episodes...')
        expert2episodes = collect_expert_episodes(env, experts, max_episode_len,
                                                  num_rollouts=Args.pret_num_rollouts, tgtval_method=Args.expert_tgtval,
                                                  learner_gamma=learner.gamma, learner_lmd=learner.lambd)
        logger.info('Collecting expert episodes...done')

        # Restrict the size of each expert buffer
        for expert_idx, expert_eps in enumerate(expert2episodes):
            expert2episodes[expert_idx] = limit_num_transitions(expert_eps, max_transitions=Args.expert_buffer_size)

        # NOTE: exposing these transitions to the experts' obs_normalizers is currently disabled.

        ### Pretraining value functions: update V^k from D^k ###
        logger.info('Updating exeperts value functions...')
        log_list = _update_expert_vfns(experts, expert2episodes, learner, pretrain=True)
        logger.info('Updating exeperts value functions...done')
        # NOTE ^ This attaches v_teacher key to each transition

        for logs in log_list:
            for log in logs:
                wandb_log({f'pretrain/expert-{key}' if key != 'vi-step' else key: val for key, val in log.items()})

    ### Main training loop ###
    for itr in range(num_train_steps):
        logger.info(f'LOPS training loop {itr + 1} / {num_train_steps}')

        # Episodes used for this iteration's learner policy update
        learner_episodes = []

        ## Evaluate the learner
        if itr % eval_every == 0:
            logs = evaluator.evaluate(learner, num_eval_episodes=Args.num_eval_episodes, update_best=True)
            # TODO: if 'frame' key exists in logs, wrap it with wandb.Video
            wandb_log({'step': itr, **logs})

        ## Roll-in (learner) and roll-out (selected policy); refit expert value functions
        if Args.algorithm != 'pg-gae':
            is_rpi = Args.algorithm == 'rpi'
            riro_kwargs = dict(return_wandblogs=True, ase_sigma=_resolve_ase_sigma(stddev_baseline),
                               expert2episodes=expert2episodes)
            if is_rpi:
                # RPI additionally passes the iteration counters to roll_in_and_out.
                riro_kwargs.update(itr=itr, num_train_steps=num_train_steps)
            all_episodes, switching_times, ro_expert_inds, wandb_logs = roll_in_and_out(
                env, learner, experts, swtime_sampler, Args.num_rollouts // 2, learner.gamma, learner.lambd,
                max_episode_len, **riro_kwargs
            )
            wandb_log({'step': itr, **wandb_logs})
            if is_rpi:
                print("ro_expert_inds", ro_expert_inds)

            _merge_riro_episodes(all_episodes, switching_times, ro_expert_inds, experts, expert2episodes,
                                 learner_episodes, learner_in_pool=is_rpi)

            # Limit the size of expert rollouts
            for expert_idx, expert_eps in enumerate(expert2episodes):
                expert2episodes[expert_idx] = limit_num_transitions(expert_eps, max_transitions=Args.expert_buffer_size)

            # Update value model V^k from D^k
            log_list = _update_expert_vfns(experts, expert2episodes, learner, pretrain=False)
            _log_expert_vfn_updates(log_list, itr)

        # Rollout with learner policy to collect data D'^n
        buffer_size = Args.learner_buffer_size
        _episodes, curr_ep = rollout(env, learner.act, max_episode_len, break_cond=functools.partial(should_flush, buffer_size=buffer_size))
        if len(curr_ep) > 0:  # Merge all episodes
            _episodes.append(curr_ep)

        print('roll-in', len(flatten(learner_episodes)))
        print('learner rollout', len(flatten(_episodes)))
        learner_episodes.extend(_episodes)
        print('roll-in + learner rollout', len(flatten(learner_episodes)))

        learner.obs_normalizer.experience(to_torch([tr['state'] for tr in flatten(learner_episodes)]))  # Critical: update obs_normalizer right before learner.update

        # Attach values and log-probs to the learner's own transitions (before expert data is merged)
        _learner_transitions = flatten(learner_episodes)
        if Args.algorithm == 'pg-gae':
            _attach_value_to_episodes(learner.vfn, _learner_transitions, obs_normalizer=learner.obs_normalizer)
        elif Args.algorithm == 'rpi':
            _attach_value_to_episodes(functools.partial(learner.vfn_aggr_plus, normalize_input=True), _learner_transitions,
                                      obs_normalizer=learner.obs_normalizer, itr=itr,
                                      learner_vfn=functools.partial(learner.vfn, normalize_input=True))
        else:
            # The aggregated value function normalizes its own inputs (normalize_input=True).
            _attach_value_to_episodes(functools.partial(learner.vfn_aggr, normalize_input=True), _learner_transitions, obs_normalizer=None)
        _attach_log_prob_to_episodes(learner.pi, _learner_transitions, obs_normalizer=learner.obs_normalizer)

        # Attach experts' logpi for (partial) expert trajectories and merge them in
        if Args.algorithm != 'pg-gae' and Args.use_riro_for_learner_pi == 'all':
            _extend_with_expert_episodes(experts, expert2episodes, learner_episodes)

        print('all in', len(flatten(learner_episodes)))
        learner_ep_lens = _attach_learner_targets(learner_episodes, learner, itr)

        ## Update learner policy:
        ## Compute the sampled gradient based on D'_n with one-step importance sampling, and update learner
        # This also updates critic if algorithm == "pg-gae"
        learner_transitions = flatten(learner_episodes)
        logger.info('Updating the learner policy...')
        loss_info = learner.update(learner_transitions, num_epochs=Args.num_epochs, batch_size=Args.batch_size)
        logs = {f'learner/{key.replace("/", "-")}': val for key, val in loss_info.items()}

        wandb_log({
            'step': itr,
            'train-learner/num_transitions': len(learner_transitions),
            'train-learner/episode_lens': wandb.Histogram(learner_ep_lens),
            'train-learner/episode_lens_mean': np.mean(learner_ep_lens),
            **logs,
        })

        # Update polynomial moving average based on learner episode lengths
        swtime_sampler.update_pma(np.mean(learner_ep_lens))

def get_expert(state_dim, act_dim, policy_head, load_path, device=None, obs_normalizer=None):
    """Build a PPOAgent, load it from ``load_path``, and give it a fresh value ensemble.

    After loading, the agent's critic is replaced by a ``ValueEnsemble`` of
    ``Args.num_expert_vfns`` members (orthogonally initialised with ``Args.expert_vfn_gain``),
    and its optimizer is rebuilt over that ensemble's parameters only.

    Args:
        state_dim: observation size used to build the networks.
        act_dim: action size used to build the policy.
        policy_head: head passed to ``Factory.create_pi``.
        load_path: path handed to ``agent.load``.
        device: target device; defaults to CUDA when available, else CPU.
        obs_normalizer: if given, replaces the normalizer created here after loading.

    Returns:
        The configured ``PPOAgent``, moved to ``device``.
    """
    from rpi.agents.ppo import PPOAgent

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Constructor arguments only (XL: not used).
    lambd_unused, gamma_unused = 0.9, 1.

    # Rebuild the architecture so the checkpoint at load_path can be loaded into it.
    policy = Factory.create_pi(state_dim, act_dim, policy_head=policy_head, initializer=None)
    critic = Factory.create_vfn(state_dim, initializer=None)

    # EmpiricalNormalization stores its statistics via `register_buffer`, so the optimizer
    # never updates them.
    loaded_normalizer = EmpiricalNormalization(state_dim, clip_threshold=5)
    ppo_optimizer = torch.optim.Adam([*policy.parameters(), *critic.parameters()], lr=3e-4, eps=1e-5)
    loaded_normalizer.to(device)

    agent = PPOAgent(policy, critic, optimizer=ppo_optimizer, obs_normalizer=loaded_normalizer,
                     gamma=gamma_unused, lambd=lambd_unused)
    agent.load(load_path)

    # A caller-supplied normalizer takes precedence over the one created above.
    agent.obs_normalizer = loaded_normalizer if obs_normalizer is None else obs_normalizer

    def build_member():
        return Factory.create_vfn(state_dim, mean_and_var=True, initializer=None, activation=nn.ReLU)

    # Replace the critic with a freshly initialised value ensemble.
    ensemble = ValueEnsemble(build_member, Args.num_expert_vfns, std_from_means=Args.std_from_means)
    for member in ensemble.vfns:
        for layer_idx in (0, 2, 4):
            ortho_init(member[layer_idx], gain=Args.expert_vfn_gain, zeros_bias=False)
    agent.vfn = ensemble

    # From here on only the ensemble is optimised.
    agent.optimizer = torch.optim.Adam(agent.vfn.parameters(), lr=1e-3)
    agent.to(device)

    return agent


def main():
    """Build the environment, experts and learner from ``Args`` and run ``train_lops``.

    The learner type depends on ``Args.algorithm``: ``PPOAgent`` for ``pg-gae``,
    ``RPIAgent`` (with ``MaxValueFnPlus`` over expert and learner value ensembles) for
    ``rpi``, and ``MambaAgent`` (with ``MaxValueFn`` over expert ensembles) otherwise.

    Raises:
        NotImplementedError: if the observation space exposes no ``low`` attribute, since
            ``state_dim`` is derived from ``observation_space.low.size``.
    """
    import gym
    import gymnasium as gymn
    from rpi.agents.mamba import MambaAgent

    # Environments may come from either gym or gymnasium; accept both Box classes everywhere.
    box_space_types = (gym.spaces.Box, gymn.spaces.Box)

    num_train_steps = Args.num_train_steps
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    set_random_seed(Args.seed)

    make_env = lambda *args, **kwargs: _make_env(Args.env_name, *args, **kwargs)  # TEMP
    test_env = make_env()

    obs_space = test_env.observation_space
    if not hasattr(obs_space, 'low'):
        raise NotImplementedError(f'Unsupported observation space {obs_space!r}: state_dim is derived from `low.size`')
    state_dim = obs_space.low.size

    if isinstance(test_env.action_space, box_space_types):
        # Continuous action space (gym or gymnasium Box)
        act_dim = test_env.action_space.low.size
        policy_head = GaussianHeadWithStateIndependentCovariance(
            action_size=act_dim,
            var_type="diagonal",
            var_func=lambda x: torch.exp(2 * x),  # Parameterize log std
            var_param_init=0,  # log std = 0 => std = 1
        )
    else:
        # Discrete action space (assuming categorical)
        act_dim = test_env.action_space.n
        policy_head = SoftmaxCategoricalHead()

    logger.info('obs_dim', state_dim)
    logger.info('act_dim', act_dim)

    # Set up learner policy
    learner_pi = Factory.create_pi(state_dim, act_dim, policy_head=policy_head)
    obs_normalizer = EmpiricalNormalization(state_dim, clip_threshold=5)
    obs_normalizer.to(device)

    # Loading: experts
    experts = [ExpertAgent(test_env, model_dir, policy=policy) for policy, model_dir in Args.experts_info]

    # Define Value Ensemble and Initialize weights for experts
    make_vfn = lambda: Factory.create_vfn(state_dim, mean_and_var=True, initializer=None, activation=nn.ReLU)
    for expert in experts:
        value_ensemble = ValueEnsemble(make_vfn, Args.num_expert_vfns, std_from_means=Args.std_from_means)
        for _vfn in value_ensemble.vfns:
            ortho_init(_vfn[0], gain=Args.expert_vfn_gain, zeros_bias=False)
            ortho_init(_vfn[2], gain=Args.expert_vfn_gain, zeros_bias=False)
            ortho_init(_vfn[4], gain=Args.expert_vfn_gain, zeros_bias=False)

        # HACK: Set value function and optimizer to each expert
        expert.vfn = value_ensemble
        expert.optimizer = torch.optim.Adam(expert.vfn.parameters(), lr=1e-3)
        expert.to(device)

    # NOTE(review): this value function is overwritten below and never used. It is kept because
    # its construction presumably consumes RNG state; removing it could change the seeded
    # initialisation of the networks built afterwards — verify before deleting.
    learner_vfn = Factory.create_vfn(state_dim)

    # Max value fn over all expert vfns
    max_vfn = MaxValueFn([expert.vfn for expert in experts])
    max_vfn.to(device)

    # Define Value Ensemble and Initialize weights for learner
    make_learner_vfn = lambda: Factory.create_vfn(state_dim, mean_and_var=True, initializer=None, activation=nn.ReLU)
    value_ensemble_learner = ValueEnsemble(make_learner_vfn, Args.num_expert_vfns, std_from_means=Args.std_from_means)
    for _vfn in value_ensemble_learner.vfns:
        ortho_init(_vfn[0], gain=Args.expert_vfn_gain, zeros_bias=False)
        ortho_init(_vfn[2], gain=Args.expert_vfn_gain, zeros_bias=False)
        ortho_init(_vfn[4], gain=Args.expert_vfn_gain, zeros_bias=False)

    learner_vfn = value_ensemble_learner
    learner_vfn.to(device)

    # Max value fn over all expert vfns and the learner vfn
    max_vfn_plus = MaxValueFnPlus([expert.vfn for expert in experts], learner_vfn, expert=Args.maxplus_expert, learner=Args.maxplus_learner, switch_rl_round=Args.maxplus_switch_rl_round, state_in_distribution=Args.state_in_distribution, num_train_steps=Args.num_train_steps, explore_decay_rate=Args.explore_decay_rate)
    max_vfn_plus.to(device)

    if Args.algorithm == 'pg-gae':
        optimizer = torch.optim.Adam(list(learner_pi.parameters()) + list(learner_vfn.parameters()), lr=1e-3, betas=(0.9, 0.99))
        learner = PPOAgent(learner_pi, learner_vfn, optimizer, obs_normalizer, gamma=Args.gamma, lambd=Args.lmd)

    elif Args.algorithm == 'rpi':
        optimizer = torch.optim.Adam(list(learner_pi.parameters()) + list(learner_vfn.parameters()), lr=1e-3, betas=(0.9, 0.99))
        learner = RPIAgent(learner_pi, learner_vfn,
                           max_vfn_plus, optimizer, obs_normalizer, gamma=Args.gamma, lambd=Args.lmd,
                           use_ppo_loss=Args.use_ppo_loss, max_grad_norm=Args.max_grad_norm)

    else:
        optimizer = torch.optim.Adam(learner_pi.parameters(), lr=1e-3, betas=(0.9, 0.99))
        learner = MambaAgent(learner_pi,
                             max_vfn, optimizer, obs_normalizer, gamma=Args.gamma, lambd=Args.lmd,
                             use_ppo_loss=Args.use_ppo_loss, max_grad_norm=Args.max_grad_norm)

    learner.to(device)

    max_episode_len = Args.max_episode_len
    swtime_sampler = SwitchingTimeSampler(time_limit=max_episode_len, use_moving_ave=(not Args.env_name.startswith('dmc')))
    evaluator = Evaluator(make_env, max_episode_len=max_episode_len)

    train_lops(make_env, evaluator, learner, experts, swtime_sampler, num_train_steps=num_train_steps, max_episode_len=Args.max_episode_len)


if __name__ == '__main__':
    # Script entry point: pick one parameter set from a sweep file, configure GPU/wandb, and train.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("sweep_file", help="sweep file")
    # Required: the index selects the parameter set and (without slurm) the GPU; a missing
    # value used to fail later with `list(sweep)[None]` / `None % len(...)`.
    parser.add_argument("-l", "--line-number", type=int, required=True,
                        help="index of the parameter set (line) in the sweep file to run")
    args = parser.parse_args()

    # Obtain kwargs from Sweep
    from params_proto.hyper import Sweep
    sweep = Sweep(Args).load(args.sweep_file)
    kwargs = list(sweep)[args.line_number]

    print("##parameter list: ",kwargs)
    # Update the parameters
    Args._update(kwargs)

    if Args.algorithm == 'pg-gae':
        # Add (maximum) number of transitions the model would experience if roll-in-and-out is performed
        Args.learner_buffer_size += Args.max_episode_len * (Args.num_rollouts // 2)

    # On slurm, cvd is already set in the shell
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        # Round-robin over a hard-coded set of local GPUs.
        avail_gpus = [1, 3]
        cvd = avail_gpus[args.line_number % len(avail_gpus)]
        os.environ["CUDA_VISIBLE_DEVICES"] = str(cvd)

    # A sweep file named "<project>--<group>.ext" logs to group <group>; otherwise 'original'.
    sweep_basename = os.path.splitext(os.path.basename(args.sweep_file))[0]
    if '--' in sweep_basename:
        sweep_basename, groupname = sweep_basename.split('--', 1)  # Remove separator!
    else:
        groupname = 'original'

    # Wandb setup
    wandb.login()
    wandb.init(
        # Set the project where this run will be logged
        project=f'alops-{sweep_basename}',
        group=groupname,
        config=vars(Args),
    )

    # Slurm specific
    upload_slurm_logs()

    main()
